"""
Phase 3: Mechanisms & Interactions — Real Exchange Rate
========================================================
Tests:
  (a) Z × NFA: net creditor/debtor asymmetry in REER response
  (b) Z × KAOPEN: capital account openness and REER adjustment
  (c) Z × trade openness: trade channel
  (d) Z × eurozone: constrained vs flexible ER regimes
  (e) Mediation: does Z affect REER through CA/NFA or directly?
  (f) Non-tradable channel (health expenditure proxy)
  (g) Chow test for structural break
  (h) Working-age share control

Output: output/tables/interactions.md, mediation.md, nontradable_channel.md, structural_break.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Project layout.
# NOTE(review): PROJECT_DIR is a hard-coded absolute (WSL-style) path, so the
# script only runs where this exact directory tree exists.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/rer")
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral" project's src/ first on the import path so
# its `model` module (which provides PanelGLS) can be imported below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Directory holding the input panel (main() reads rer_panel.csv from here).
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
# Directory that build_table() writes its markdown tables into.
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# ISO3 codes of the 38 OECD member countries.
# NOTE(review): not referenced anywhere else in the visible part of this script.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


def stars(p):
    """Map a p-value to the conventional significance marker.

    Returns '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.10 and ''
    otherwise. Values for which every comparison is False (e.g. NaN) yield ''.
    """
    # Cut-offs are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, min_obs=50):
    """Estimate one PanelGLS model and return its results as a flat dict.

    Regressors that are not columns of *df* are dropped silently, so callers
    may list optional variables. Repeated names are collapsed (first
    occurrence kept) because a duplicated column would make the design
    matrix rank-deficient. Rows with a missing value in the dependent
    variable, any kept regressor, or the panel identifiers 'iso3'/'year'
    are excluded before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel data; must contain 'iso3' and 'year' columns.
    dep_var : str
        Dependent-variable column.
    regressors : list of str
        Candidate regressor columns, in the order they enter the model.
    label : str
        Model name used in console output and stored under 'label'.
    min_obs : int, default 50
        Minimum number of complete rows required to attempt a fit.

    Returns
    -------
    dict or None
        Keys 'label', 'dep_var', 'n_obs', 'n_countries', 'r_squared', 'rho'
        plus 'coef_<name>', 'se_<name>', 'p_<name>' for every regressor used.
        None (after printing the reason) if *dep_var* is missing, no
        regressor is available, or fewer than *min_obs* complete rows remain.
    """
    # dict.fromkeys de-duplicates while preserving the caller's order.
    regressors = [r for r in dict.fromkeys(regressors) if r in df.columns]
    if dep_var not in df.columns:
        print(f"  [{label}] {dep_var} missing — skipping")
        return None
    if not regressors:
        print(f"  [{label}] No regressors available — skipping")
        return None
    # The panel identifiers are passed to PanelGLS.fit, so rows lacking them
    # must be excluded as well.
    sub = df.dropna(subset=[dep_var] + regressors + ['iso3', 'year']).copy()
    if len(sub) < min_obs:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None
    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)
    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, "
          f"R²={gls.r_squared:.4f}")
    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared, 'rho': gls.rho,
    }
    # Coefficient i of the fitted model corresponds to regressors[i].
    for i, name in enumerate(regressors):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<30} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")
    return results


def build_table(results, key_vars, notes, filename, title, out_dir=None):
    """Write a markdown summary of several run_model results.

    The file contains a model-summary table (label, dependent variable, N,
    number of countries, R²), a key-coefficient table listing every
    variable in *key_vars* that a model actually estimated, and *notes*
    rendered in italics.

    Parameters
    ----------
    results : list of dict
        Result dicts as returned by run_model.
    key_vars : list of str
        Variables to report in the coefficient table; repeats are ignored.
    notes : str
        Footnote text.
    filename : str
        Output file name, relative to the output directory.
    title : str
        Markdown H1 title.
    out_dir : str or pathlib.Path, optional
        Output directory; defaults to TABLES_DIR. Created if it is missing.

    Returns
    -------
    pathlib.Path or None
        The path written, or None when *results* is empty (nothing written).
    """
    if not results:
        return None
    md = [f"# {title}\n",
          "| Model | Dep Var | N | Countries | R² |",
          "|---|---|---|---|---|"]
    md.extend(
        f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
        f"| {r['n_countries']} | {r['r_squared']:.3f} |"
        for r in results
    )
    md += ["\n## Key Coefficients\n",
           "| Model | Variable | Coef | SE | p-value | Sig |",
           "|---|---|---|---|---|---|"]
    # A repeated key variable would otherwise produce duplicate rows.
    unique_vars = list(dict.fromkeys(key_vars))
    for r in results:
        for var in unique_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
    md.append(f"\n*{notes}*")
    base = TABLES_DIR if out_dir is None else Path(out_dir)
    out = base / filename
    # On a fresh checkout the output directory may not exist yet.
    out.parent.mkdir(parents=True, exist_ok=True)
    # Tables contain non-ASCII characters (R², ×, →); do not depend on the
    # platform's locale encoding.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"\n  Saved: {out}")
    return out


_DEP_VAR = 'log_reer_combined'


def _banner(title):
    """Print a section header framed by 70-character rules."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _collect(results, r):
    """Append *r* to *results* when the model was estimated; return *r* unchanged."""
    if r:
        results.append(r)
    return r


def _present(df, names):
    """Return the subset of *names* that are columns of *df*, in order."""
    return [v for v in names if v in df.columns]


def _interaction_models(df, demo_vars, controls_bs):
    """Part A: Z × moderator interactions plus the NFA creditor/debtor split."""
    _banner("PART A: INTERACTION MODELS")
    results = []

    # (interaction-column suffix, extra main-effect regressors, label).
    # Interaction columns are named '<Z>_x_<suffix>', e.g. 'Z_1_x_nfa'.
    specs = [
        ('nfa', [], "A1: Z×NFA → log REER"),
        ('kaopen', [], "A2: Z×KAOPEN → log REER"),
        ('trade', [], "A3: Z×trade → log REER"),
        ('emu', ['eurozone'], "A4: Z×EMU → log REER"),
    ]
    all_int = []
    for suffix, extra, label in specs:
        inter = _present(df, [f'{z}_x_{suffix}' for z in demo_vars])
        all_int += inter
        _collect(results, run_model(df, _DEP_VAR,
                                    demo_vars + controls_bs + extra + inter, label))

    # A5: separate fits for net creditors (lagged NFA >= 0) and debtors (< 0).
    if 'nfa_gdp_lag' in df.columns:
        creditor = df[df['nfa_gdp_lag'] >= 0].copy()
        debtor = df[df['nfa_gdp_lag'] < 0].copy()
        _collect(results, run_model(creditor, _DEP_VAR, demo_vars + controls_bs,
                                    "A5a: creditor Z → log REER"))
        _collect(results, run_model(debtor, _DEP_VAR, demo_vars + controls_bs,
                                    "A5b: debtor Z → log REER"))

    build_table(results, demo_vars + all_int + ['eurozone'],
                "Interaction models test how NFA, openness, trade, and EMU moderate "
                "the demographic effect on REER",
                "interactions.md",
                "Interaction Models: Demographics × Moderators → log(REER)")


def _report_attenuation(base, with_nfa):
    """Print the % attenuation of the Z_1 coefficient from model B1 to B2.

    Nothing is printed unless both models were estimated, both report a
    Z_1 coefficient, and the B1 coefficient is non-zero.
    """
    if not (base and with_nfa):
        return
    z1_no_nfa = base.get('coef_Z_1')
    z1_with_nfa = with_nfa.get('coef_Z_1')
    if z1_no_nfa is None or z1_with_nfa is None or z1_no_nfa == 0:
        return
    attenuation = (1 - z1_with_nfa / z1_no_nfa) * 100
    print(f"\n  ★ Z₁ attenuation (no NFA → + NFA): {attenuation:.1f}%")


def _mediation_models(df, demo_vars, controls_bs):
    """Part B: sequentially add external-position controls to the Z → REER model."""
    _banner("PART B: MEDIATION — Does Z affect REER through NFA or directly?")
    results = []
    controls_no_nfa = _present(df, ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'log_gdp_pc'])

    # Keep B1/B2 by name: if either is skipped, positional indexing into
    # `results` would compare the wrong pair of models.
    b1 = _collect(results, run_model(df, _DEP_VAR, demo_vars + controls_no_nfa,
                                     "B1: Z → REER (no NFA)"))
    b2 = _collect(results, run_model(df, _DEP_VAR, demo_vars + controls_bs,
                                     "B2: Z → REER (+ NFA)"))

    # B3-B5: each adds one further control on top of the B2 specification.
    extras = [
        ('ca_gdp', "B3: Z → REER (+ NFA + CA)"),
        ('trade_openness', "B4: Z → REER (+ NFA + trade)"),
        ('working_age_share', "B5: Z → REER (+ WAS)"),
    ]
    for extra, label in extras:
        if extra in df.columns:
            _collect(results, run_model(df, _DEP_VAR,
                                        demo_vars + controls_bs + [extra], label))

    _report_attenuation(b1, b2)

    build_table(results,
                demo_vars + ['nfa_gdp_lag', 'ca_gdp', 'trade_openness', 'working_age_share'],
                "Sequential addition of external position controls",
                "mediation.md",
                "Mediation: External Position Channel vs Direct Demographic Effect")


def _nontradable_channel(df, demo_vars, controls_bs):
    """Part C: health expenditure as a non-tradable demand channel."""
    _banner("PART C: NON-TRADABLE CHANNEL")
    results = []
    if 'health_exp_gdp' in df.columns:
        # C1: Z → health spending.
        _collect(results, run_model(df, 'health_exp_gdp',
                                    demo_vars + ['rgdp_growth', 'log_gdp_pc'],
                                    "C1: Z → health_exp_gdp"))
        # C2: health spending → REER (without Z).
        _collect(results, run_model(df, _DEP_VAR, ['health_exp_gdp'] + controls_bs,
                                    "C2: health_exp → log REER"))
        # C3: Z → REER controlling for health spending.
        _collect(results, run_model(df, _DEP_VAR,
                                    demo_vars + controls_bs + ['health_exp_gdp'],
                                    "C3: Z → REER | health_exp"))

    build_table(results,
                demo_vars + ['health_exp_gdp'],
                "Non-tradable demand channel: aging → health spending → REER appreciation",
                "nontradable_channel.md",
                "Non-Tradable Channel: Demographics → Health Spending → REER")


def _chow_test(df, regressors, min_obs=50):
    """Chow F-test for a break between years <= 2007 and >= 2008.

    Returns (F, p), or None when either subsample has fewer than *min_obs*
    complete rows, no regressor is available, or the residual degrees of
    freedom are not positive.

    NOTE(review): residuals are formed as y - X @ beta using PanelGLS.beta,
    i.e. without any intercept, fixed effects or GLS transformation the
    estimator may apply internally. Confirm this matches PanelGLS; if not,
    the RSS terms are only approximate and F can even come out negative.
    """
    cols = _present(df, regressors)
    full = df.dropna(subset=[_DEP_VAR] + cols)
    pre_c = full[full['year'] <= 2007]
    post_c = full[full['year'] >= 2008]
    if len(pre_c) < min_obs or len(post_c) < min_obs:
        return None
    k = len(cols)
    dof = len(full) - 2 * k
    if k == 0 or dof <= 0:
        return None

    def rss(sub):
        y = sub[_DEP_VAR].values
        X = sub[cols].values
        gls = PanelGLS()
        gls.fit(y, X, sub['iso3'].values, sub['year'].values)
        return np.sum((y - X @ gls.beta) ** 2)

    rss_full, rss_pre, rss_post = rss(full), rss(pre_c), rss(post_c)
    F_chow = ((rss_full - rss_pre - rss_post) / k) / ((rss_pre + rss_post) / dof)
    # sf is 1 - cdf computed without cancellation error for small p-values.
    return F_chow, stats.f.sf(F_chow, k, dof)


def _structural_breaks(df, demo_vars, controls_bs):
    """Part D: subsample fits around 2000 and the GFC, plus a Chow test."""
    _banner("PART D: STRUCTURAL BREAKS")
    results = []
    regs = demo_vars + controls_bs

    windows = [
        (df['year'] <= 2007, "D1: pre-GFC"),
        (df['year'] >= 2008, "D2: post-GFC"),
        (df['year'] <= 2000, "D3: pre-2000"),
        (df['year'] >= 2001, "D4: 2001-2024"),
    ]
    for mask, label in windows:
        _collect(results, run_model(df[mask].copy(), _DEP_VAR, regs, label))

    print("\n  Chow test (REER model) ...")
    notes = "Pre/post GFC and pre/post 2000 subsamples"
    chow = _chow_test(df, regs)
    if chow is not None:
        F_chow, p_chow = chow
        print(f"    Chow F = {F_chow:.2f}, p = {p_chow:.4f}")
        # Persist the Chow result alongside the subsample table.
        notes += f"; Chow test (break after 2007): F = {F_chow:.2f}, p = {p_chow:.4f}"

    build_table(results, demo_vars + ['log_gdp_pc'],
                notes,
                "structural_break.md",
                "Structural Break Analysis")


def main():
    """Run all Phase 3 analyses on the processed REER panel.

    Reads PROCESSED_DIR / "rer_panel.csv" and, via build_table, writes
    interactions.md, mediation.md, nontradable_channel.md and
    structural_break.md (each only if at least one model was estimated).
    """
    print("=" * 70)
    print("PHASE 3: Mechanisms & Interactions — Real Exchange Rate")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "rer_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls_bs = _present(df, ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen',
                                'nfa_gdp_lag', 'log_gdp_pc'])

    _interaction_models(df, demo_vars, controls_bs)
    _mediation_models(df, demo_vars, controls_bs)
    _nontradable_channel(df, demo_vars, controls_bs)
    _structural_breaks(df, demo_vars, controls_bs)

    _banner("Phase 3 complete.")


if __name__ == "__main__":
    main()
